#include "torch_npu/csrc/distributed/PrefixStore.hpp"
#include <utility>

namespace c10d_npu {
    PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<c10d::Store> store)
        : prefix_(std::move(prefix)), store_(std::move(store)) {}

    std::string PrefixStore::joinKey(const std::string &key)
    {
        return prefix_ + "/" + key;
    }

    std::vector<std::string> PrefixStore::joinKeys(
        const std::vector<std::string> &keys)
    {
        std::vector<std::string> joinedKeys;
        joinedKeys.reserve(keys.size());
        for (const auto &key : keys) {
            joinedKeys.emplace_back(joinKey(key));
        }
        return joinedKeys;
    }

    void PrefixStore::set(
        const std::string &key,
        const std::vector<uint8_t> &value)
    {
        store_->set(joinKey(key), value);
    }

    std::vector<uint8_t> PrefixStore::compareSet(
        const std::string &key,
        const std::vector<uint8_t> &expectedValue,
        const std::vector<uint8_t> &desiredValue)
    {
        return store_->compareSet(joinKey(key), expectedValue, desiredValue);
    }

    std::vector<uint8_t> PrefixStore::get(const std::string &key)
    {
        return store_->get(joinKey(key));
    }

    int64_t PrefixStore::add(const std::string &key, int64_t value)
    {
        return store_->add(joinKey(key), value);
    }

    bool PrefixStore::deleteKey(const std::string &key)
    {
        return store_->deleteKey(joinKey(key));
    }

    int64_t PrefixStore::getNumKeys()
    {
        return store_->getNumKeys();
    }

    bool PrefixStore::check(const std::vector<std::string> &keys)
    {
        auto joinedKeys = joinKeys(keys);
        return store_->check(joinedKeys);
    }

    void PrefixStore::wait(const std::vector<std::string> &keys)
    {
        auto joinedKeys = joinKeys(keys);
        store_->wait(joinedKeys);
    }

    void PrefixStore::wait(
        const std::vector<std::string> &keys,
        const std::chrono::milliseconds &timeout)
    {
        auto joinedKeys = joinKeys(keys);
        store_->wait(joinedKeys, timeout);
    }

    const std::chrono::milliseconds &PrefixStore::getTimeout() const noexcept
    {
        return store_->getTimeout();
    }

    void PrefixStore::setTimeout(const std::chrono::milliseconds &timeout)
    {
        store_->setTimeout(timeout);
    }

    void PrefixStore::append(
        const std::string &key,
        const std::vector<uint8_t> &value)
        {
        store_->append(joinKey(key), value);
    }

    std::vector<std::vector<uint8_t>> PrefixStore::multiGet(
        const std::vector<std::string> &keys)
    {
        std::vector<std::string> prefixed_keys;
        prefixed_keys.reserve(keys.size());
        for (auto &key : keys) {
            prefixed_keys.push_back(joinKey(key));
        }
        return store_->multiGet(prefixed_keys);
    }

    void PrefixStore::multiSet(
        const std::vector<std::string> &keys,
        const std::vector<std::vector<uint8_t>> &values)
    {
        std::vector<std::string> prefixed_keys;
        prefixed_keys.reserve(keys.size());
        for (auto &key : keys) {
            prefixed_keys.push_back(joinKey(key));
        }
        store_->multiSet(prefixed_keys, values);
    }

    // Returns true if this store support append, multiGet and multiSet
    bool PrefixStore::hasExtendedApi() const
    {
        return store_->hasExtendedApi();
    }

    c10::intrusive_ptr<c10d::Store> PrefixStore::getUnderlyingStore()
    {
        return store_;
    }

    c10::intrusive_ptr<c10d::Store> PrefixStore::getUnderlyingNonPrefixStore()
    {
        c10::intrusive_ptr<c10d::Store> store = store_;

        while (store) {
            // Attempt to dynamically cast to PrefixStore
            PrefixStore *asPrefixStore = dynamic_cast<PrefixStore *>(store.get());
            if (asPrefixStore) {
                store = asPrefixStore->getUnderlyingStore();
            } else {
                break; // We've reached a non-PrefixStore
            }
        }

        TORCH_CHECK(
            store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
        return store;
    }

} // namespace c10d
